package main

import (
	"database/sql"
	"flag"
	"fmt"
	_ "github.com/go-sql-driver/mysql"
	"log"
	"os"
	"os/exec"
	"path"
	"sort"
	"strings"
)

// Package-level state shared by the generator.
var (
	// conn is the database handle opened in main; getAllTables and getColumns
	// query through it.
	conn *sql.DB
	// err is a package-level error variable that main assigns to.
	err error
)

// main generates Go model structs from a MySQL schema.
//
// It reads connection settings and the output path from command-line flags,
// loads every base table and its columns from information_schema, writes one
// struct per table to the output file (creating its directory if needed) and
// finally runs `go fmt` on that file.
func main() {
	host := flag.String("host", "127.0.0.1", "MySQL server host")
	port := flag.String("port", "3306", "MySQL server port")
	user := flag.String("user", "root", "MySQL user name")
	password := flag.String("password", "", "MySQL password")
	db := flag.String("db", "test", "schema (database) to generate models for")
	out := flag.String("out", "models/model.go", "path of the generated Go file; its directory name becomes the package name")
	flag.Parse()

	dsn := fmt.Sprintf("%s:%s@tcp(%s:%s)/%s?charset=utf8mb4&parseTime=True&loc=Local",
		*user, *password, *host, *port, *db)
	conn, err = sql.Open("mysql", dsn)
	if err != nil {
		panic(err)
	}
	defer conn.Close()

	// Make sure the output directory exists.
	dir, _ := parsePath(*out)
	if dir != "" {
		if err = os.MkdirAll(dir, os.ModePerm); err != nil {
			log.Println("创建目录失败:" + err.Error())
			return
		}
	}
	packageName := getPackageName(dir)

	// Query the schema and render the source before opening (and truncating)
	// the output file, so a failed query does not leave an empty file behind.
	getAllTables(*db)
	getColumns(*db)
	content := generateModel(packageName)

	var f *os.File
	f, err = os.OpenFile(*out, os.O_RDWR|os.O_CREATE|os.O_TRUNC, 0666)
	if err != nil {
		log.Println(err)
		return
	}
	_, err = f.WriteString(content)
	// Close explicitly (not deferred) so the handle is released before go fmt
	// rewrites the file; report the write error first if there was one.
	if closeErr := f.Close(); err == nil {
		err = closeErr
	}
	if err != nil {
		log.Println(err)
		return
	}

	// exec.Command does not parse a command line: the program and each
	// argument must be passed separately. A single "go fmt" name would be
	// looked up as an executable literally called "go fmt".
	if output, fmtErr := exec.Command("go", "fmt", *out).CombinedOutput(); fmtErr != nil {
		log.Printf("go fmt %s: %v\n%s", *out, fmtErr, output)
	}
}

type (
	// column holds one row read from information_schema.COLUMNS: the column
	// name, its DATA_TYPE, its comment and the name of its table.
	column struct {
		Name      string
		DataType  string
		Comment   string
		TableName string
	}

	// Table is one table of the target schema together with its columns, in
	// the order they were appended while reading information_schema.COLUMNS.
	Table struct {
		Name    string   `json:"name"`
		Comment string   `json:"comment"`
		Columns []column `json:"columns"`
	}
)

// tableSet collects the tables read from the schema, keyed by table name.
var tableSet = map[string]*Table{}

// parsePath splits out at its final slash into a directory (which keeps its
// trailing slash) and a file name. The directory is reported as "" when out
// has no directory part or the directory is exactly "./".
func parsePath(out string) (string, string) {
	d, name := path.Split(out)
	if d == "./" {
		return "", name
	}
	return d, name
}

// getAllTables loads every base table (TABLE_TYPE = 'BASE TABLE') of the given
// schema from information_schema.TABLES into tableSet, keyed by table name.
// Each entry starts with an empty column list.
//
// It panics if the query, a row scan, or the row iteration fails.
func getAllTables(database string) {
	rows, err := conn.Query("select TABLE_NAME,TABLE_COMMENT from information_schema.TABLES where TABLE_SCHEMA = ? and TABLE_TYPE = ?", database, "BASE TABLE")
	if err != nil {
		panic(err)
	}
	// Close the result set even if iteration panics part-way through.
	defer rows.Close()

	for rows.Next() {
		// NullString so a NULL value is read as "" instead of failing the scan.
		var name, comment sql.NullString
		if err = rows.Scan(&name, &comment); err != nil {
			panic(fmt.Errorf("scan information_schema.TABLES row: %w", err))
		}
		if !name.Valid {
			continue
		}
		tableSet[name.String] = &Table{
			Name:    name.String,
			Comment: comment.String,
			Columns: make([]column, 0),
		}
	}
	if err = rows.Err(); err != nil {
		panic(fmt.Errorf("iterate information_schema.TABLES: %w", err))
	}
}

// checkAndMkdir is an empty placeholder: it performs no check and creates no
// directory. Output-directory creation is currently done inline in main.
func checkAndMkdir(dir string) {}

// getColumns reads every column of the given schema from
// information_schema.COLUMNS and appends it to the Columns slice of its table
// in tableSet, in ORDINAL_POSITION order. Columns whose table is not in
// tableSet are skipped; the tables are loaded separately and only base tables
// are included, so view columns are skipped.
//
// It panics if the query, a row scan, or the row iteration fails.
func getColumns(database string) {
	// TABLE_SCHEMA is fixed by the WHERE clause, so order by table and then by
	// column position to get a deterministic per-table column order.
	rows, err := conn.Query("SELECT COLUMN_NAME,DATA_TYPE,TABLE_NAME,COLUMN_COMMENT FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = ? order by TABLE_NAME asc,ORDINAL_POSITION asc", database)
	if err != nil {
		panic(err)
	}
	// Close the result set even if iteration panics part-way through.
	defer rows.Close()

	for rows.Next() {
		// NullString so a NULL value is read as "" instead of failing the scan.
		var name, dataType, tableName, comment sql.NullString
		if err = rows.Scan(&name, &dataType, &tableName, &comment); err != nil {
			panic(fmt.Errorf("scan information_schema.COLUMNS row: %w", err))
		}
		t, ok := tableSet[tableName.String]
		if !ok {
			continue
		}
		t.Columns = append(t.Columns, column{
			Name:      name.String,
			DataType:  dataType.String,
			Comment:   comment.String,
			TableName: tableName.String,
		})
	}
	if err = rows.Err(); err != nil {
		panic(fmt.Errorf("iterate information_schema.COLUMNS: %w", err))
	}
}

// generateModel renders the Go source file for every table in tableSet.
//
// The output begins with "package <packageName>". It adds `import "time"`
// only when at least one field's Go type is time.Time, then emits one struct
// per table in ascending table-name order. Each field carries:
//   - the UpperCamel column name,
//   - the Go type looked up in typeForMysqlToGo,
//   - a lowerCamel json tag.
//
// Each struct also gets a TableName method that returns the original table
// name.
func generateModel(packageName string) string {
	const tableNameFunc = `
func (*%s)TableName() string {
	return "%s"
}
`
	// Comments are emitted as // line comments, so a line break inside a MySQL
	// comment would otherwise spill raw text into the generated code.
	flatten := strings.NewReplacer("\r\n", " ", "\n", " ", "\r", " ")

	tables := make([]*Table, 0, len(tableSet))
	for _, t := range tableSet {
		tables = append(tables, t)
	}
	// Map iteration order is random; sort for deterministic output.
	sort.Slice(tables, func(i, j int) bool { return tables[i].Name < tables[j].Name })

	var body strings.Builder
	needTime := false
	for _, t := range tables {
		structName := snake2BigHump(t.Name)
		fmt.Fprintf(&body, "//%s %s\n", structName, flatten.Replace(t.Comment))
		fmt.Fprintf(&body, "type %s struct {\n", structName)
		for _, col := range t.Columns {
			if col.Comment != "" {
				fmt.Fprintf(&body, "\t// %s\n", flatten.Replace(col.Comment))
			}
			goType, ok := typeForMysqlToGo[strings.ToLower(col.DataType)]
			if !ok {
				// DATA_TYPE missing from the mapping: fall back to string so the
				// generated file still compiles instead of having an empty type.
				goType = "string"
			}
			// Decide the import from the emitted types, not by searching the
			// text, which would also match "time.Time" inside a comment.
			if goType == "time.Time" {
				needTime = true
			}
			fmt.Fprintf(&body, "\t%s\t%s\t`json:\"%s\"`\n",
				snake2BigHump(col.Name), goType, snake2LittleHump(col.Name))
		}
		body.WriteString("}\n")
		fmt.Fprintf(&body, tableNameFunc, structName, t.Name)
		body.WriteString("\n\n\n")
	}

	var src strings.Builder
	fmt.Fprintf(&src, "package %s\n\n", packageName)
	if needTime {
		src.WriteString("import \"time\"\n\n")
	}
	src.WriteString(body.String())
	return src.String()
}

// getPackageName returns the last element of dir to use as the Go package
// name of the generated file. An empty dir means the current working
// directory. Both "/" and "\" are treated as separators, because os.Getwd
// returns backslash paths on Windows; trailing separators are ignored.
// If no usable name can be derived (the working directory cannot be read, or
// dir is only separators), it returns "models", the directory of the default
// -out path.
func getPackageName(dir string) string {
	const fallback = "models"
	if dir == "" {
		wd, wdErr := os.Getwd()
		if wdErr != nil {
			return fallback
		}
		dir = wd
	}
	dir = strings.TrimRight(dir, `/\`)
	name := dir[strings.LastIndexAny(dir, `/\`)+1:]
	if name == "" {
		return fallback
	}
	return name
}

// snake2BigHump converts a snake_case name to UpperCamelCase, for example
// "user_id" -> "UserId". Only the first rune of each underscore-separated part
// is upper-cased; the rest of the part is kept unchanged. Empty parts (from
// leading, trailing or doubled underscores) are skipped, so "" yields "".
func snake2BigHump(str string) string {
	var b strings.Builder
	b.Grow(len(str))
	for _, part := range strings.Split(str, "_") {
		if part == "" {
			// Skipping avoids indexing into an empty part.
			continue
		}
		// Work on runes, not bytes: cutting the first byte off a multi-byte
		// UTF-8 character (e.g. a Chinese column name) would corrupt it.
		runes := []rune(part)
		b.WriteString(strings.ToUpper(string(runes[0])))
		b.WriteString(string(runes[1:]))
	}
	return b.String()
}

// snake2LittleHump converts a snake_case name to lowerCamelCase, for example
// "user_id" -> "userId". The first rune of the first non-empty part is
// lower-cased and the first rune of every later part is upper-cased; the rest
// of each part is kept unchanged. Empty parts (from leading, trailing or
// doubled underscores) are skipped, so "" yields "".
func snake2LittleHump(str string) string {
	var b strings.Builder
	b.Grow(len(str))
	first := true
	for _, part := range strings.Split(str, "_") {
		if part == "" {
			// Skipping avoids indexing into an empty part.
			continue
		}
		// Work on runes, not bytes, so multi-byte UTF-8 characters stay intact.
		runes := []rune(part)
		head := string(runes[0])
		if first {
			head = strings.ToLower(head)
			first = false
		} else {
			head = strings.ToUpper(head)
		}
		b.WriteString(head)
		b.WriteString(string(runes[1:]))
	}
	return b.String()
}

// typeForMysqlToGo maps a MySQL information_schema.COLUMNS DATA_TYPE value to
// the Go type used for the generated struct field.
//
// datetime and timestamp map to time.Time. date and time map to string; the
// trailing comments note time.Time as the alternative.
//
// NOTE(review): the "... unsigned" keys only match if DATA_TYPE carries the
// unsigned modifier; MySQL normally reports it in COLUMN_TYPE instead, so
// confirm. Unsigned 64-bit values may not fit the mapped int type.
var typeForMysqlToGo = map[string]string{
	"int":                "int",
	"integer":            "int",
	"tinyint":            "int",
	"smallint":           "int",
	"mediumint":          "int",
	"bigint":             "int",
	"int unsigned":       "int",
	"integer unsigned":   "int",
	"tinyint unsigned":   "int",
	"smallint unsigned":  "int",
	"mediumint unsigned": "int",
	"bigint unsigned":    "int",
	"bit":                "int",
	"bool":               "bool",
	"enum":               "string",
	"set":                "string",
	"varchar":            "string",
	"char":               "string",
	"tinytext":           "string",
	"mediumtext":         "string",
	"text":               "string",
	"longtext":           "string",
	"blob":               "string",
	"tinyblob":           "string",
	"mediumblob":         "string",
	"longblob":           "string",
	"date":               "string",    // time.Time
	"datetime":           "time.Time", // time.Time
	"timestamp":          "time.Time", // time.Time
	"time":               "string",    // time.Time
	"float":              "float64",
	"double":             "float64",
	"decimal":            "float64",
	"binary":             "string",
	"varbinary":          "string",
	"json":               "string",
	"year":               "int",
}
